from torch import Tensor
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Map-style dataset pairing a feature tensor with a target tensor.

    Sample ``i`` is ``(x[i], y[i])``. Both tensors are indexed along their
    first dimension, which must have the same size in each.
    """

    def __init__(self, x: Tensor, y: Tensor):
        """Store the feature and target tensors.

        Args:
            x: Features; the first dimension indexes samples.
            y: Targets; the first dimension indexes samples and must match
                ``x.shape[0]``.

        Raises:
            ValueError: If ``x`` and ``y`` disagree on the number of samples.
        """
        super().__init__()
        # Fail fast here rather than with an IndexError mid-epoch (y shorter)
        # or silently dropping targets (y longer).
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                "x and y must have the same number of samples, "
                f"got {x.shape[0]} and {y.shape[0]}"
            )
        self.X = x  # feature tensor, indexed along dim 0
        self.Y = y  # target tensor, indexed along dim 0

    def __getitem__(self, index):
        """Return the ``(feature, target)`` pair at ``index``."""
        return self.X[index], self.Y[index]

    def __len__(self) -> int:
        """Return the number of samples, i.e. the size of dim 0 of ``X``."""
        return self.X.shape[0]
